Gaussian Process Classification

Preliminary steps

Loading necessary packages

using Plots
using HTTP, CSV
using DataFrames: DataFrame
using AugmentedGaussianProcesses
using MLDataUtils

Loading the banana dataset from OpenML

# Download the banana dataset (2 input features, binary class labels) from OpenML.
data = HTTP.get("https://www.openml.org/data/get_csv/1586217/phpwRjVjk")
data = CSV.read(data.body, DataFrame)
# Recode the labels from {1, 2} to {1, -1}, the convention the likelihoods expect.
data.Class[data.Class .== 2] .= -1
data = Matrix(data)
# First two columns are the inputs, the last column is the label.
X = data[:, 1:2]
Y = Int.(data[:, end]);
# 50/50 train/test split, treating rows (first dimension) as observations.
(X_train, y_train), (X_test, y_test) = splitobs((X, Y), 0.5, ObsDim.First())
(([1.14 -0.114; -1.52 -1.15; … ; -0.0708 0.439; 0.177 -1.37], [1, -1, 1, -1, -1, -1, 1, -1, -1, -1  …  1, -1, 1, -1, -1, -1, 1, -1, -1, 1]), ([-0.135 0.136; -0.288 0.385; … ; 0.769 0.772; -0.255 -0.142], [-1, -1, -1, 1, 1, -1, -1, 1, -1, 1  …  -1, -1, -1, -1, -1, -1, -1, -1, 1, -1]))

We create a function to visualize the data

# Scatter-plot the inputs `X` (one column per axis), colored by label `Y`.
# `size` controls the figure dimensions in pixels.
function plot_data(X, Y; size=(300, 500))
    axes_cols = eachcol(X)
    return Plots.scatter(
        axes_cols...;
        group=Y,
        alpha=0.2,
        markerstrokewidth=0.0,
        lab="",
        size=size,
    )
end
plot_data(X, Y; size=(500, 500))

Run sparse classification with an increasing number of inducing points

# Numbers of inducing points to try for the sparse (SVGP) models.
Ms = [4, 8, 16, 32, 64]
# One slot per sparse model plus a final slot for the full (non-sparse) model.
models = Vector{AbstractGPModel}(undef, length(Ms) + 1)
# Squared-exponential kernel with a unit length-scale; kept fixed during
# training since hyperparameter optimization is disabled below.
kernel = SqExponentialKernel() ∘ ScaleTransform(1.0)
for (i, num_inducing) in enumerate(Ms)
    @info "Training with $(num_inducing) points"
    m = SVGP(
        kernel,
        LogisticLikelihood(),
        AnalyticVI(),
        # Initialize the inducing point locations via k-means on the inputs.
        inducingpoints(KmeansAlg(num_inducing), X);
        optimiser=false,  # do not optimize kernel hyperparameters
        Zoptimiser=false, # do not optimize inducing point locations
    )
    @time train!(m, X_train, y_train, 20) # 20 variational-inference iterations
    models[i] = m
end
[ Info: Training with 4 points
  0.016693 seconds (3.00 k allocations: 8.588 MiB)
[ Info: Training with 8 points
  0.013838 seconds (3.01 k allocations: 13.972 MiB)
[ Info: Training with 16 points
  0.020187 seconds (3.04 k allocations: 24.884 MiB)
[ Info: Training with 32 points
  0.036689 seconds (3.09 k allocations: 47.315 MiB)
[ Info: Training with 64 points
  0.129250 seconds (3.39 k allocations: 94.619 MiB, 46.72% gc time)

Running the full model

@info "Running full model"
# Non-sparse VGP using every training point; serves as the reference model
# against which the sparse approximations are compared.
mfull = VGP(X_train, y_train, kernel, LogisticLikelihood(), AnalyticVI(); optimiser=false)
@time train!(mfull, 5) # 5 variational-inference iterations
models[end] = mfull
Variational Gaussian Process with a BernoulliLikelihood{GPLikelihoods.LogisticLink}(GPLikelihoods.LogisticLink(LogExpFunctions.logistic)) infered by Analytic Variational Inference 

We create functions to compute predictions on a grid and to plot them

"""
    compute_grid(model, n_grid=50; mins=[-3.25, -2.85], maxs=[3.65, 3.4])

Evaluate `model`'s predicted class probability on a regular `n_grid` × `n_grid`
grid. The grid bounds default to the extent of the banana dataset but can be
overridden via the `mins`/`maxs` keywords (lower/upper corner of the grid).

Returns `(y_grid, x_lin, y_lin)`: the flat vector of probabilities and the two
axis ranges needed to reshape/plot it.
"""
function compute_grid(model, n_grid=50; mins=[-3.25, -2.85], maxs=[3.65, 3.4])
    x_lin = range(mins[1], maxs[1]; length=n_grid)
    y_lin = range(mins[2], maxs[2]; length=n_grid)
    # Cartesian product of the two axes, turned into a vector of 2-element points.
    x_grid = Iterators.product(x_lin, y_lin)
    y_grid, _ = proba_y(model, vec(collect.(x_grid)))
    return y_grid, x_lin, y_lin
end

# Plot the data together with the model's decision boundary (the p = 0.5
# contour of the predicted class probability). For sparse (SVGP) models the
# inducing point locations are overlaid in black.
function plot_model(model, X, Y, title=nothing; size=(300, 500))
    n_grid = 50
    y_pred, x_lin, y_lin = compute_grid(model, n_grid)
    # Default title: number of inducing points for sparse models, "full" otherwise.
    title = if isnothing(title)
        (model isa SVGP ? "M = $(AGP.dim(model[1]))" : "full")
    else
        title
    end
    p = plot_data(X, Y; size=size)
    Plots.contour!(
        p,
        x_lin,
        y_lin,
        # The predictions come back flat; reshape (and transpose) onto the grid.
        reshape(y_pred, n_grid, n_grid)';
        cbar=false,
        levels=[0.5], # single contour at the decision threshold
        fill=false,
        color=:black,
        linewidth=2.0,
        title=title,
    )
    if model isa SVGP
        # Overlay the inducing point locations of the sparse model.
        Plots.scatter!(
            p, eachrow(hcat(AGP.Zview(model[1])...))...; msize=2.0, color="black", lab=""
        )
    end
    return p
end;

Now run the prediction for every model and visualize the differences

# Side-by-side comparison of every sparse model and the full model.
# Ref(...) protects X and Y from being iterated by the broadcast.
Plots.plot(
    plot_model.(models, Ref(X), Ref(Y))...; layout=(1, length(models)), size=(1000, 200)
)

Bayesian SVM vs Logistic

We now create a model with the Bayesian SVM likelihood

# Same full VGP setup as before, but with the Bayesian SVM likelihood in
# place of the logistic likelihood.
mbsvm = VGP(X_train, y_train, kernel, BayesianSVM(), AnalyticVI(); optimiser=false)
@time train!(mbsvm, 5) # 5 variational-inference iterations
(Variational Gaussian Process with a BernoulliLikelihood{AugmentedGaussianProcesses.SVMLink}(AugmentedGaussianProcesses.SVMLink()) infered by Analytic Variational Inference , (local_vars = (c = [0.19066459123793808, 0.2316454066533993, 0.0031544471158863247, 0.13783608431929248, 0.23493768647556923, 4.754524325433761, 1.6762565726960272, 1.4822238505113354, 0.45473690730484545, 0.07199295745870501  …  0.3674542315138281, 0.46906021066906367, 0.10311036630751802, 0.820139046182534, 3.2678129506899065, 0.6793569622948201, 5.001778909691415, 0.4947182427870569, 1.2812846432152984, 2.9561752313139205], θ = [2.29015552640959, 2.077725421646932, 17.80485226800952, 2.6935096539564287, 2.0631160426355173, 0.4586131076254875, 0.7723777475453046, 0.8213780661294696, 1.482927428694241, 3.726962239772856  …  1.6496749262220423, 1.4601104286319262, 3.1142167420131313, 1.1042216441045376, 0.5531862862463869, 1.2132519128677302, 0.44713406146199275, 1.42174279924846, 0.8834402651824085, 0.5816140819628853]), opt_state = (NamedTuple(),), hyperopt_state = (NamedTuple(),), kernel_matrices = ((K = LinearAlgebra.Cholesky{Float64, Matrix{Float64}}([1.0000499987500624 0.017000746952647305 … 0.4123128729363635 0.2857887156916333; 0.017001596968745064 0.9999054828347788 … 0.09200744609514472 0.22644779606578086; … ; 0.4123334880646449 0.0990083766302707 … 0.010077146944507058 -1.6990382020865232e-6; 0.28580300477019976 0.23128501449942165 … 0.18882343252913514 0.010065643970489872], 'U', 0),),)))

And compare it with the Logistic likelihood

# Compare the decision boundaries of the logistic and Bayesian SVM likelihoods.
Plots.plot(
    plot_model.(
        [models[end], mbsvm], Ref(X), Ref(Y), ["Logistic", "BSVM"]; size=(500, 250)
    )...;
    layout=(1, 2),
)

This page was generated using Literate.jl.